{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "parking_her.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.5"
    },
    "kernelspec": {
      "name": "python3",
      "language": "python",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "pycharm": {
      "stem_cell": {
        "cell_type": "raw",
        "source": [],
        "metadata": {
          "collapsed": false
        }
      }
    }
  },
  "cells": [
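    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5eeje4O8fviH",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "# Parking with Hindsight Experience Replay\n",
        "\n",
        "This notebook trains an agent on the goal-conditioned `parking-v0` task from highway-env using Hindsight Experience Replay (HER), then records a few test episodes as videos.\n",
        "\n",
        "## Warming up\n",
        "We start with a few useful installs and imports:"
      ]
    },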
    {
      "cell_type": "code",
      "metadata": {
        "id": "bzMSuJEOfviP",
        "pycharm": {
          "is_executing": false,
          "name": "#%%\n"
        }
      },
      "source": [
        "# Install environment and agent.\n",
        "# Versions are pinned because later cells rely on APIs that newer releases removed:\n",
        "# `HerReplayBuffer(online_sampling=..., max_episode_length=...)` (stable-baselines3 1.x),\n",
        "# `gym.wrappers.Monitor` (old gym) and `automatic_rendering_callback` (old highway-env).\n",
        "%pip install -q \"highway-env<1.6\" \"stable-baselines3>=1.1,<1.7\" \"sb3-contrib>=1.1,<1.7\"\n",
        "\n",
        "# Environment\n",
        "import gym\n",
        "import highway_env  # noqa: F401 -- imported for its side effect of registering the highway-env environments (e.g. \"parking-v0\") with gym\n",
        "\n",
        "# Agent (SAC is imported so that it can replace TQC in the training cell)\n",
        "from stable_baselines3 import HerReplayBuffer, SAC\n",
        "from sb3_contrib import TQC"
      ],
      "execution_count": null,
      "outputs": []
    },
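    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "_wACJRDjqP-f"
      },
      "source": [
        "## Training\n",
        "\n",
        "We train a TQC agent (SAC can be swapped in) with a HER replay buffer: when a batch is sampled, some transitions have their goal replaced by a goal that was actually achieved later in the same episode (`goal_selection_strategy='future'`), so the agent also learns from episodes that missed the target."
      ]
    },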
    {
      "cell_type": "code",
      "metadata": {
        "pycharm": {
          "name": "#%% \n"
        },
        "id": "Y5TOvonYqP-g"
      },
      "source": [
        "# Training budget and HER settings.\n",
        "TOTAL_TIMESTEPS = int(5e4)\n",
        "N_SAMPLED_GOAL = 4  # relabeled (\"virtual\") transitions per real transition\n",
        "MAX_EPISODE_LENGTH = 100  # should match parking-v0's episode length (TODO confirm against the env config)\n",
        "\n",
        "env = gym.make(\"parking-v0\")\n",
        "her_kwargs = dict(\n",
        "    n_sampled_goal=N_SAMPLED_GOAL,\n",
        "    goal_selection_strategy='future',  # relabel with goals achieved later in the same episode\n",
        "    online_sampling=True,  # relabel when a batch is sampled instead of storing extra transitions\n",
        "    max_episode_length=MAX_EPISODE_LENGTH,\n",
        ")\n",
        "\n",
        "# You can replace TQC with SAC agent\n",
        "model = TQC(\n",
        "    'MultiInputPolicy',  # the goal env returns dict observations\n",
        "    env,\n",
        "    replay_buffer_class=HerReplayBuffer,\n",
        "    replay_buffer_kwargs=her_kwargs,\n",
        "    verbose=1,\n",
        "    buffer_size=int(1e6),\n",
        "    learning_rate=1e-3,\n",
        "    gamma=0.95,\n",
        "    batch_size=1024,\n",
        "    tau=0.05,\n",
        "    policy_kwargs=dict(net_arch=[512, 512, 512]),\n",
        ")\n",
        "model.learn(TOTAL_TIMESTEPS)\n"
      ],
      "execution_count": null,
      "outputs": []
    },
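    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n2Bu_Pqop0E7"
      },
      "source": [
        "## Visualize a few episodes\n",
        "\n",
        "We first install the rendering dependencies, start a virtual display so that frames can be rendered on a machine without a screen (such as Colab), and define a simple helper function that embeds the recorded episodes in the notebook:"
      ]
    },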
    {
      "cell_type": "code",
      "metadata": {
        "id": "so7yH4ucyB-3"
      },
      "source": [
        "%pip install -q gym pyvirtualdisplay\n",
        "# python3-opengl rather than python-opengl: the Python 2 package is not available on recent\n",
        "# Ubuntu images, and a single unknown package makes apt-get abort the whole install (xvfb included).\n",
        "!apt-get install -y xvfb python3-opengl ffmpeg\n",
        "\n",
        "from IPython import display as ipythondisplay\n",
        "from pyvirtualdisplay import Display\n",
        "from gym.wrappers import Monitor\n",
        "from pathlib import Path\n",
        "import base64\n",
        "from tqdm.notebook import trange\n",
        "\n",
        "# Invisible X display so the environment can render frames on a headless machine.\n",
        "display = Display(visible=0, size=(1400, 900))\n",
        "display.start()\n",
        "\n",
        "def show_video(path=\"video\"):\n",
        "    \"\"\"Embed every .mp4 recording found in the folder `path` into the notebook output.\n",
        "\n",
        "    Each video is inlined as a base64 data URI. Files are shown in sorted file-name\n",
        "    order, because Path.glob yields them in arbitrary (filesystem) order.\n",
        "    \"\"\"\n",
        "    html = []\n",
        "    for mp4 in sorted(Path(path).glob(\"*.mp4\")):\n",
        "        video_b64 = base64.b64encode(mp4.read_bytes())\n",
        "        html.append('''<video alt=\"{}\" autoplay \n",
        "                      loop controls style=\"height: 400px;\">\n",
        "                      <source src=\"data:video/mp4;base64,{}\" type=\"video/mp4\" />\n",
        "                 </video>'''.format(mp4, video_b64.decode('ascii')))\n",
        "    ipythondisplay.display(ipythondisplay.HTML(data=\"<br>\".join(html)))\n"
      ],
      "execution_count": null,
      "outputs": []
    },
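    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_8L6vEPWyea7"
      },
      "source": [
        "### Test the policy\n",
        "\n",
        "We run the trained agent with deterministic actions for three episodes, record every episode to `./video`, and display the recordings:"
      ]
    },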
    {
      "cell_type": "code",
      "metadata": {
        "id": "xOcOP7Of18T2",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "source": [
        "# Record every test episode to ./video (force=True clears earlier Monitor recordings there).\n",
        "env = Monitor(gym.make(\"parking-v0\"), './video', force=True, video_callable=lambda episode: True)\n",
        "for _ in trange(3, desc=\"Test episodes\"):\n",
        "    obs = env.reset()\n",
        "    done = False\n",
        "    # The Monitor starts a new video recorder on each reset, so point highway-env's\n",
        "    # automatic rendering callback at the current recorder once per episode.\n",
        "    env.unwrapped.automatic_rendering_callback = env.video_recorder.capture_frame\n",
        "    while not done:\n",
        "        action, _ = model.predict(obs, deterministic=True)\n",
        "        obs, reward, done, info = env.step(action)\n",
        "env.close()  # closes the Monitor and its last video file\n",
        "show_video()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}